{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D:\\VirtualProject\\Python37Env\\torch_py38\\Scripts\\pip.exe\n",
      "D:\\python37\\Scripts\\pip.exe\n"
     ]
    }
   ],
   "source": [
    "! where pip\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-09-10T08:27:36.351187600Z",
     "start_time": "2025-09-10T08:27:35.834664700Z"
    }
   },
   "id": "3a3798848c5c7f09"
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 参考的网站\n",
    "https://www.kaggle.com/code/arunmohan003/transformer-from-scratch-using-pytorch"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "98aa183082170624"
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-09-10T08:29:18.179263Z",
     "start_time": "2025-09-10T08:29:00.156292Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.1.2+cpu\n"
     ]
    }
   ],
   "source": [
    "# importing required libraries\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import math,copy,re\n",
    "import warnings\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torchtext\n",
    "import matplotlib.pyplot as plt\n",
    "warnings.simplefilter(\"ignore\")\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "class Embedding(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_dim):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            vocab_size: size of vocabulary\n",
    "            embed_dim: dimension of embeddings\n",
    "        \"\"\"\n",
    "        super(Embedding, self).__init__()\n",
    "        self.embed = nn.Embedding(vocab_size, embed_dim)\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            x: input vector\n",
    "        Returns:\n",
    "            out: embedding vector\n",
    "        \"\"\"\n",
    "        out = self.embed(x)\n",
    "        return out"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-09-10T08:30:06.016737500Z",
     "start_time": "2025-09-10T08:30:05.983912100Z"
    }
   },
   "id": "fba3424538ed09f1"
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [],
   "source": [
    "# register buffer in Pytorch ->\n",
    "# If you have parameters in your model, which should be saved and restored in the state_dict,\n",
    "# but not trained by the optimizer, you should register them as buffers.\n",
    "\n",
    "\n",
    "class PositionalEmbedding(nn.Module):\n",
    "    def __init__(self,max_seq_len,embed_model_dim):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            seq_len: length of input sequence\n",
    "            embed_model_dim: demension of embedding\n",
    "        \"\"\"\n",
    "        super(PositionalEmbedding, self).__init__()\n",
    "        self.embed_dim = embed_model_dim\n",
    "\n",
    "        pe = torch.zeros(max_seq_len,self.embed_dim)\n",
    "        for pos in range(max_seq_len):\n",
    "            for i in range(0,self.embed_dim,2):\n",
    "                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/self.embed_dim)))\n",
    "                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/self.embed_dim)))\n",
    "        pe = pe.unsqueeze(0)\n",
    "        self.register_buffer('pe', pe)\n",
    "\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            x: input vector\n",
    "        Returns:\n",
    "            x: output\n",
    "        \"\"\"\n",
    "      \n",
    "        # make embeddings relatively larger\n",
    "        x = x * math.sqrt(self.embed_dim)\n",
    "        #add constant to embedding\n",
    "        seq_len = x.size(1)\n",
    "        x = x + torch.autograd.Variable(self.pe[:,:seq_len], requires_grad=False)\n",
    "        return x"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-09-10T08:35:01.763318300Z",
     "start_time": "2025-09-10T08:35:01.734303Z"
    }
   },
   "id": "3447f7ee6eef123a"
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "class MultiHeadAttention(nn.Module):\n",
    "    def __init__(self, embed_dim=512, n_heads=8):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            embed_dim: dimension of embeding vector output\n",
    "            n_heads: number of self attention heads\n",
    "        \"\"\"\n",
    "        super(MultiHeadAttention, self).__init__()\n",
    "\n",
    "        self.embed_dim = embed_dim    #512 dim\n",
    "        self.n_heads = n_heads   #8\n",
    "        self.single_head_dim = int(self.embed_dim / self.n_heads)   #512/8 = 64  . each key,query, value will be of 64d\n",
    "       \n",
    "        #key,query and value matrixes    #64 x 64   \n",
    "        self.query_matrix = nn.Linear(self.single_head_dim , self.single_head_dim ,bias=False)  # single key matrix for all 8 keys #512x512\n",
    "        self.key_matrix = nn.Linear(self.single_head_dim  , self.single_head_dim, bias=False)\n",
    "        self.value_matrix = nn.Linear(self.single_head_dim ,self.single_head_dim , bias=False)\n",
    "        self.out = nn.Linear(self.n_heads*self.single_head_dim ,self.embed_dim) \n",
    "\n",
    "    def forward(self,key,query,value,mask=None):    #batch_size x sequence_length x embedding_dim    # 32 x 10 x 512\n",
    "        \n",
    "        \"\"\"\n",
    "        Args:\n",
    "           key : key vector\n",
    "           query : query vector\n",
    "           value : value vector\n",
    "           mask: mask for decoder\n",
    "        \n",
    "        Returns:\n",
    "           output vector from multihead attention\n",
    "        \"\"\"\n",
    "        batch_size = key.size(0)\n",
    "        seq_length = key.size(1)\n",
    "        \n",
    "        # query dimension can change in decoder during inference. \n",
    "        # so we cant take general seq_length\n",
    "        seq_length_query = query.size(1)\n",
    "        \n",
    "        # 32x10x512\n",
    "        key = key.view(batch_size, seq_length, self.n_heads, self.single_head_dim)  #batch_size x sequence_length x n_heads x single_head_dim = (32x10x8x64)\n",
    "        query = query.view(batch_size, seq_length_query, self.n_heads, self.single_head_dim) #(32x10x8x64)\n",
    "        value = value.view(batch_size, seq_length, self.n_heads, self.single_head_dim) #(32x10x8x64)\n",
    "       \n",
    "        k = self.key_matrix(key)       # (32x10x8x64)\n",
    "        q = self.query_matrix(query)   \n",
    "        v = self.value_matrix(value)\n",
    "\n",
    "        q = q.transpose(1,2)  # (batch_size, n_heads, seq_len, single_head_dim)    # (32 x 8 x 10 x 64)\n",
    "        k = k.transpose(1,2)  # (batch_size, n_heads, seq_len, single_head_dim)\n",
    "        v = v.transpose(1,2)  # (batch_size, n_heads, seq_len, single_head_dim)\n",
    "       \n",
    "        # computes attention\n",
    "        # adjust key for matrix multiplication\n",
    "        k_adjusted = k.transpose(-1,-2)  #(batch_size, n_heads, single_head_dim, seq_ken)  #(32 x 8 x 64 x 10)\n",
    "        product = torch.matmul(q, k_adjusted)  #(32 x 8 x 10 x 64) x (32 x 8 x 64 x 10) = #(32x8x10x10)\n",
    "      \n",
    "        \n",
    "        # fill those positions of product matrix as (-1e20) where mask positions are 0\n",
    "        if mask is not None:\n",
    "             product = product.masked_fill(mask == 0, float(\"-1e20\"))\n",
    "\n",
    "        #divising by square root of key dimension\n",
    "        product = product / math.sqrt(self.single_head_dim) # / sqrt(64)\n",
    "\n",
    "        #applying softmax\n",
    "        scores = F.softmax(product, dim=-1)\n",
    " \n",
    "        #mutiply with value matrix\n",
    "        scores = torch.matmul(scores, v)  ##(32x8x 10x 10) x (32 x 8 x 10 x 64) = (32 x 8 x 10 x 64) \n",
    "        \n",
    "        #concatenated output\n",
    "        concat = scores.transpose(1,2).contiguous().view(batch_size, seq_length_query, self.single_head_dim*self.n_heads)  # (32x8x10x64) -> (32x10x8x64)  -> (32,10,512)\n",
    "        \n",
    "        output = self.out(concat) #(32,10,512) -> (32,10,512)\n",
    "       \n",
    "        return output"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-09-10T08:46:14.931575700Z",
     "start_time": "2025-09-10T08:46:14.911578600Z"
    }
   },
   "id": "af12485f402ee37"
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [],
   "source": [
    "class TransformerBlock(nn.Module):\n",
    "    def __init__(self, embed_dim, expansion_factor=4, n_heads=8):\n",
    "        super(TransformerBlock, self).__init__()\n",
    "        \n",
    "        \"\"\"\n",
    "        Args:\n",
    "           embed_dim: dimension of the embedding\n",
    "           expansion_factor: fator ehich determines output dimension of linear layer\n",
    "           n_heads: number of attention heads\n",
    "        \n",
    "        \"\"\"\n",
    "        self.attention = MultiHeadAttention(embed_dim, n_heads)\n",
    "        \n",
    "        self.norm1 = nn.LayerNorm(embed_dim)\n",
    "        self.norm2 = nn.LayerNorm(embed_dim)\n",
    "        \n",
    "        self.feed_forward = nn.Sequential(\n",
    "                          nn.Linear(embed_dim, expansion_factor*embed_dim),\n",
    "                          nn.ReLU(),\n",
    "                          nn.Linear(expansion_factor*embed_dim, embed_dim)\n",
    "        )\n",
    "\n",
    "        self.dropout1 = nn.Dropout(0.2)\n",
    "        self.dropout2 = nn.Dropout(0.2)\n",
    "\n",
    "    def forward(self,key,query,value):\n",
    "        \n",
    "        \"\"\"\n",
    "        Args:\n",
    "           key: key vector\n",
    "           query: query vector\n",
    "           value: value vector\n",
    "           norm2_out: output of transformer block\n",
    "        \n",
    "        \"\"\"\n",
    "        \n",
    "        attention_out = self.attention(key,query,value)  #32x10x512\n",
    "        attention_residual_out = attention_out + value  #32x10x512\n",
    "        norm1_out = self.dropout1(self.norm1(attention_residual_out)) #32x10x512\n",
    "\n",
    "        feed_fwd_out = self.feed_forward(norm1_out) #32x10x512 -> #32x10x2048 -> 32x10x512\n",
    "        feed_fwd_residual_out = feed_fwd_out + norm1_out #32x10x512\n",
    "        norm2_out = self.dropout2(self.norm2(feed_fwd_residual_out)) #32x10x512\n",
    "\n",
    "        return norm2_out\n",
    "\n",
    "\n",
    "\n",
    "class TransformerEncoder(nn.Module):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        seq_len : length of input sequence\n",
    "        embed_dim: dimension of embedding\n",
    "        num_layers: number of encoder layers\n",
    "        expansion_factor: factor which determines number of linear layers in feed forward layer\n",
    "        n_heads: number of heads in multihead attention\n",
    "        \n",
    "    Returns:\n",
    "        out: output of the encoder\n",
    "    \"\"\"\n",
    "    def __init__(self, seq_len, vocab_size, embed_dim, num_layers=2, expansion_factor=4, n_heads=8):\n",
    "        super(TransformerEncoder, self).__init__()\n",
    "        \n",
    "        self.embedding_layer = Embedding(vocab_size, embed_dim)\n",
    "        self.positional_encoder = PositionalEmbedding(seq_len, embed_dim)\n",
    "\n",
    "        self.layers = nn.ModuleList([TransformerBlock(embed_dim, expansion_factor, n_heads) for i in range(num_layers)])\n",
    "    \n",
    "    def forward(self, x):\n",
    "        embed_out = self.embedding_layer(x)\n",
    "        out = self.positional_encoder(embed_out)\n",
    "        for layer in self.layers:\n",
    "            out = layer(out,out,out)\n",
    "\n",
    "        return out  #32x10x512"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-09-10T08:59:01.182780600Z",
     "start_time": "2025-09-10T08:59:01.085779600Z"
    }
   },
   "id": "7bc7cf57237295f1"
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [],
   "source": [
    "class DecoderBlock(nn.Module):\n",
    "    def __init__(self, embed_dim, expansion_factor=4, n_heads=8):\n",
    "        super(DecoderBlock, self).__init__()\n",
    "\n",
    "        \"\"\"\n",
    "        Args:\n",
    "           embed_dim: dimension of the embedding\n",
    "           expansion_factor: fator ehich determines output dimension of linear layer\n",
    "           n_heads: number of attention heads\n",
    "        \n",
    "        \"\"\"\n",
    "        self.attention = MultiHeadAttention(embed_dim, n_heads=8)\n",
    "        self.norm = nn.LayerNorm(embed_dim)\n",
    "        self.dropout = nn.Dropout(0.2)\n",
    "        self.transformer_block = TransformerBlock(embed_dim, expansion_factor, n_heads)\n",
    "        \n",
    "    \n",
    "    def forward(self, key, query, x,mask):\n",
    "        \n",
    "        \"\"\"\n",
    "        Args:\n",
    "           key: key vector\n",
    "           query: query vector\n",
    "           value: value vector\n",
    "           mask: mask to be given for multi head attention \n",
    "        Returns:\n",
    "           out: output of transformer block\n",
    "    \n",
    "        \"\"\"\n",
    "        \n",
    "        #we need to pass mask mask only to fst attention\n",
    "        attention = self.attention(x,x,x,mask=mask) #32x10x512\n",
    "        value = self.dropout(self.norm(attention + x))\n",
    "        \n",
    "        out = self.transformer_block(key, query, value)\n",
    "\n",
    "        \n",
    "        return out\n",
    "\n",
    "\n",
    "class TransformerDecoder(nn.Module):\n",
    "    def __init__(self, target_vocab_size, embed_dim, seq_len, num_layers=2, expansion_factor=4, n_heads=8):\n",
    "        super(TransformerDecoder, self).__init__()\n",
    "        \"\"\"  \n",
    "        Args:\n",
    "           target_vocab_size: vocabulary size of taget\n",
    "           embed_dim: dimension of embedding\n",
    "           seq_len : length of input sequence\n",
    "           num_layers: number of encoder layers\n",
    "           expansion_factor: factor which determines number of linear layers in feed forward layer\n",
    "           n_heads: number of heads in multihead attention\n",
    "        \n",
    "        \"\"\"\n",
    "        self.word_embedding = nn.Embedding(target_vocab_size, embed_dim)\n",
    "        self.position_embedding = PositionalEmbedding(seq_len, embed_dim)\n",
    "\n",
    "        self.layers = nn.ModuleList(\n",
    "            [\n",
    "                DecoderBlock(embed_dim, expansion_factor=4, n_heads=8) \n",
    "                for _ in range(num_layers)\n",
    "            ]\n",
    "\n",
    "        )\n",
    "        self.fc_out = nn.Linear(embed_dim, target_vocab_size)\n",
    "        self.dropout = nn.Dropout(0.2)\n",
    "\n",
    "    def forward(self, x, enc_out, mask):\n",
    "        \n",
    "        \"\"\"\n",
    "        Args:\n",
    "            x: input vector from target\n",
    "            enc_out : output from encoder layer\n",
    "            trg_mask: mask for decoder self attention\n",
    "        Returns:\n",
    "            out: output vector\n",
    "        \"\"\"\n",
    "            \n",
    "        \n",
    "        x = self.word_embedding(x)  #32x10x512\n",
    "        x = self.position_embedding(x) #32x10x512\n",
    "        x = self.dropout(x)\n",
    "     \n",
    "        for layer in self.layers:\n",
    "            x = layer(enc_out, x, enc_out, mask) \n",
    "\n",
    "        out = F.softmax(self.fc_out(x))\n",
    "\n",
    "        return out"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-09-10T09:02:15.891000500Z",
     "start_time": "2025-09-10T09:02:15.864999600Z"
    }
   },
   "id": "b42950e539d7573f"
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "class Transformer(nn.Module):\n",
    "    def __init__(self, embed_dim, src_vocab_size, target_vocab_size, seq_length,num_layers=2, expansion_factor=4, n_heads=8):\n",
    "        super(Transformer, self).__init__()\n",
    "        \n",
    "        \"\"\"  \n",
    "        Args:\n",
    "           embed_dim:  dimension of embedding \n",
    "           src_vocab_size: vocabulary size of source\n",
    "           target_vocab_size: vocabulary size of target\n",
    "           seq_length : length of input sequence\n",
    "           num_layers: number of encoder layers\n",
    "           expansion_factor: factor which determines number of linear layers in feed forward layer\n",
    "           n_heads: number of heads in multihead attention\n",
    "        \n",
    "        \"\"\"\n",
    "        \n",
    "        self.target_vocab_size = target_vocab_size\n",
    "\n",
    "        self.encoder = TransformerEncoder(seq_length, src_vocab_size, embed_dim, num_layers=num_layers, expansion_factor=expansion_factor, n_heads=n_heads)\n",
    "        self.decoder = TransformerDecoder(target_vocab_size, embed_dim, seq_length, num_layers=num_layers, expansion_factor=expansion_factor, n_heads=n_heads)\n",
    "        \n",
    "    \n",
    "    def make_trg_mask(self, trg):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            trg: target sequence\n",
    "        Returns:\n",
    "            trg_mask: target mask\n",
    "        \"\"\"\n",
    "        batch_size, trg_len = trg.shape\n",
    "        # returns the lower triangular part of matrix filled with ones\n",
    "        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(\n",
    "            batch_size, 1, trg_len, trg_len\n",
    "        )\n",
    "        return trg_mask    \n",
    "\n",
    "    def decode(self,src,trg):\n",
    "        \"\"\"\n",
    "        for inference\n",
    "        Args:\n",
    "            src: input to encoder \n",
    "            trg: input to decoder\n",
    "        out:\n",
    "            out_labels : returns final prediction of sequence\n",
    "        \"\"\"\n",
    "        trg_mask = self.make_trg_mask(trg)\n",
    "        enc_out = self.encoder(src)\n",
    "        out_labels = []\n",
    "        batch_size,seq_len = src.shape[0],src.shape[1]\n",
    "        #outputs = torch.zeros(seq_len, batch_size, self.target_vocab_size)\n",
    "        out = trg\n",
    "        for i in range(seq_len): #10\n",
    "            out = self.decoder(out,enc_out,trg_mask) #bs x seq_len x vocab_dim\n",
    "            # taking the last token\n",
    "            out = out[:,-1,:]\n",
    "     \n",
    "            out = out.argmax(-1)\n",
    "            out_labels.append(out.item())\n",
    "            out = torch.unsqueeze(out,axis=0)\n",
    "          \n",
    "        \n",
    "        return out_labels\n",
    "    \n",
    "    def forward(self, src, trg):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            src: input to encoder \n",
    "            trg: input to decoder\n",
    "        out:\n",
    "            out: final vector which returns probabilities of each target word\n",
    "        \"\"\"\n",
    "        trg_mask = self.make_trg_mask(trg)\n",
    "        enc_out = self.encoder(src)\n",
    "   \n",
    "        outputs = self.decoder(trg, enc_out, trg_mask)\n",
    "        return outputs"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-09-10T09:07:50.555828400Z",
     "start_time": "2025-09-10T09:07:50.522828100Z"
    }
   },
   "id": "aebd140947aa1402"
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 12]) torch.Size([2, 12])\n"
     ]
    },
    {
     "data": {
      "text/plain": "Transformer(\n  (encoder): TransformerEncoder(\n    (embedding_layer): Embedding(\n      (embed): Embedding(11, 512)\n    )\n    (positional_encoder): PositionalEmbedding()\n    (layers): ModuleList(\n      (0-5): 6 x TransformerBlock(\n        (attention): MultiHeadAttention(\n          (query_matrix): Linear(in_features=64, out_features=64, bias=False)\n          (key_matrix): Linear(in_features=64, out_features=64, bias=False)\n          (value_matrix): Linear(in_features=64, out_features=64, bias=False)\n          (out): Linear(in_features=512, out_features=512, bias=True)\n        )\n        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n        (feed_forward): Sequential(\n          (0): Linear(in_features=512, out_features=2048, bias=True)\n          (1): ReLU()\n          (2): Linear(in_features=2048, out_features=512, bias=True)\n        )\n        (dropout1): Dropout(p=0.2, inplace=False)\n        (dropout2): Dropout(p=0.2, inplace=False)\n      )\n    )\n  )\n  (decoder): TransformerDecoder(\n    (word_embedding): Embedding(11, 512)\n    (position_embedding): PositionalEmbedding()\n    (layers): ModuleList(\n      (0-5): 6 x DecoderBlock(\n        (attention): MultiHeadAttention(\n          (query_matrix): Linear(in_features=64, out_features=64, bias=False)\n          (key_matrix): Linear(in_features=64, out_features=64, bias=False)\n          (value_matrix): Linear(in_features=64, out_features=64, bias=False)\n          (out): Linear(in_features=512, out_features=512, bias=True)\n        )\n        (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n        (dropout): Dropout(p=0.2, inplace=False)\n        (transformer_block): TransformerBlock(\n          (attention): MultiHeadAttention(\n            (query_matrix): Linear(in_features=64, out_features=64, bias=False)\n            (key_matrix): Linear(in_features=64, out_features=64, bias=False)\n            (value_matrix): Linear(in_features=64, out_features=64, bias=False)\n            (out): Linear(in_features=512, out_features=512, bias=True)\n          )\n          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n          (feed_forward): Sequential(\n            (0): Linear(in_features=512, out_features=2048, bias=True)\n            (1): ReLU()\n            (2): Linear(in_features=2048, out_features=512, bias=True)\n          )\n          (dropout1): Dropout(p=0.2, inplace=False)\n          (dropout2): Dropout(p=0.2, inplace=False)\n        )\n      )\n    )\n    (fc_out): Linear(in_features=512, out_features=11, bias=True)\n    (dropout): Dropout(p=0.2, inplace=False)\n  )\n)"
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "src_vocab_size = 11\n",
    "target_vocab_size = 11\n",
    "num_layers = 6\n",
    "seq_length= 12\n",
    "\n",
    "\n",
    "# let 0 be sos token and 1 be eos token\n",
    "src = torch.tensor([[0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1], \n",
    "                    [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1]])\n",
    "target = torch.tensor([[0, 1, 7, 4, 3, 5, 9, 2, 8, 10, 9, 1], \n",
    "                       [0, 1, 5, 6, 2, 4, 7, 6, 2, 8, 10, 1]])\n",
    "\n",
    "print(src.shape,target.shape)\n",
    "model = Transformer(embed_dim=512, src_vocab_size=src_vocab_size, \n",
    "                    target_vocab_size=target_vocab_size, seq_length=seq_length,\n",
    "                    num_layers=num_layers, expansion_factor=4, n_heads=8)\n",
    "model"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-09-10T09:08:12.852228400Z",
     "start_time": "2025-09-10T09:08:12.426742800Z"
    }
   },
   "id": "b48395ea18d6c6aa"
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [
    {
     "data": {
      "text/plain": "torch.Size([2, 12, 11])"
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out = model(src, target)\n",
    "out.shape"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-09-10T09:08:32.372831600Z",
     "start_time": "2025-09-10T09:08:32.177832300Z"
    }
   },
   "id": "66ab2b8098026e01"
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 12]) torch.Size([1, 1])\n"
     ]
    },
    {
     "data": {
      "text/plain": "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# inference\n",
    "model = Transformer(embed_dim=512, src_vocab_size=src_vocab_size, \n",
    "                    target_vocab_size=target_vocab_size, seq_length=seq_length, \n",
    "                    num_layers=num_layers, expansion_factor=4, n_heads=8)\n",
    "                  \n",
    "\n",
    "\n",
    "src = torch.tensor([[0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1]])\n",
    "trg = torch.tensor([[0]])\n",
    "print(src.shape,trg.shape)\n",
    "out = model.decode(src, trg)\n",
    "out"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-09-10T09:09:10.800213500Z",
     "start_time": "2025-09-10T09:09:09.312704Z"
    }
   },
   "id": "aee2c4d059a9c7df"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   },
   "id": "55938bea41da0400"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
